# coding: utf-8

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

# Current Operation Coverage:
#   DatabaseVulnerabilityAssessmentRuleBaselines: 0/3
#   DatabaseVulnerabilityAssessments: 4/4
#   ServerVulnerabilityAssessments: 4/4
#   ManagedDatabaseVulnerabilityAssessmentRuleBaselines: 0/3
#   ManagedDatabaseVulnerabilityAssessments: 4/4
#   ManagedInstanceVulnerabilityAssessments: 4/4

import unittest

import azure.mgmt.sql
from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer

AZURE_LOCATION = "eastus"


class MgmtSqlTest(AzureMgmtTestCase):
    """Scenario tests for the SQL vulnerability-assessment operation groups.

    Both scenarios are decorated with ``@unittest.skip("hard to test")``, so
    they do not run by default.
    """

    def setUp(self):
        """Create the SQL management client and, in live mode, a storage client."""
        super(MgmtSqlTest, self).setUp()
        self.mgmt_client = self.create_mgmt_client(azure.mgmt.sql.SqlManagementClient)

        if self.is_live:
            # Storage is only provisioned in live runs (see create_blob_container),
            # so the storage SDK is imported and its client created only then.
            from azure.mgmt.storage import StorageManagementClient

            self.storage_client = self.create_mgmt_client(StorageManagementClient)

    @staticmethod
    def _vulnerability_assessment_body(account_name, container_name, access_key):
        """Return a new minimal vulnerability-assessment request body.

        :param account_name: Storage account named in ``storage_container_path``.
        :param container_name: Blob container named in ``storage_container_path``.
        :param access_key: Value sent as ``storage_account_access_key``.
        :returns: A freshly built dict, so no body object is shared between requests.
        """
        return {
            "storage_container_path": "https://{}.blob.core.windows.net/{}/".format(account_name, container_name),
            "storage_account_access_key": access_key,
        }

    def create_blob_container(self, location, group_name, account_name, container_name):
        """Create a StorageV2 account with an empty blob container and return an account key.

        Requires ``self.storage_client``, which ``setUp`` only creates in live mode.

        :param location: Azure region for the storage account.
        :param group_name: Resource group that will contain the account.
        :param account_name: Name of the storage account to create.
        :param container_name: Name of the blob container to create in the account.
        :returns: ``keys[0].value`` from the response of regenerating ``key2``.
        """
        # StorageAccountCreate[put]
        account_body = {
            "sku": {"name": "Standard_GRS"},
            "kind": "StorageV2",
            "location": location,
            "encryption": {
                "services": {
                    "file": {"key_type": "Account", "enabled": True},
                    "blob": {"key_type": "Account", "enabled": True},
                },
                "key_source": "Microsoft.Storage",
            },
            "tags": {"key1": "value1", "key2": "value2"},
        }
        # Wait for the long-running create to finish: the container needs the account.
        self.storage_client.storage_accounts.begin_create(group_name, account_name, account_body).result()

        # PutContainers[put]
        self.storage_client.blob_containers.create(group_name, account_name, container_name, {})

        # StorageAccountRegenerateKey[post]
        key_list = self.storage_client.storage_accounts.regenerate_key(group_name, account_name, {"key_name": "key2"})
        return key_list.keys[0].value

    @unittest.skip("hard to test")
    def test_managed_vulnerability_assessment(self):
        """Create, read, list and delete managed-instance and managed-database vulnerability assessments.

        The resource group and managed instance below are referenced by name
        only. This test does not create them, so they must already exist.
        """

        RESOURCE_GROUP = "testManagedInstance"
        MANAGED_INSTANCE_NAME = "testinstancexxy"
        DATABASE_NAME = "mydatabase"
        STORAGE_ACCOUNT_NAME = "mystorageaccountxydb"
        BLOB_CONTAINER_NAME = "myblobcontainer"
        VULNERABILITY_ASSESSMENT_NAME = "default"

        if self.is_live:
            ACCESS_KEY = self.create_blob_container(
                AZURE_LOCATION, RESOURCE_GROUP, STORAGE_ACCOUNT_NAME, BLOB_CONTAINER_NAME
            )
        else:
            # Non-live runs provision no storage, so a placeholder key is used.
            ACCESS_KEY = "accesskey"

        # --------------------------------------------------------------------------
        # /ManagedDatabases/put/Creates a new managed database with minimal properties[put]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_databases.begin_create_or_update(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            database_name=DATABASE_NAME,
            parameters={"location": AZURE_LOCATION},
        ).result()

        # --------------------------------------------------------------------------
        # /ManagedInstanceVulnerabilityAssessments/put/Create a managed instance's vulnerability assessment with minimal parameters, when storageContainerSasKey is specified[put]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_instance_vulnerability_assessments.create_or_update(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
            parameters=self._vulnerability_assessment_body(STORAGE_ACCOUNT_NAME, BLOB_CONTAINER_NAME, ACCESS_KEY),
        )

        # --------------------------------------------------------------------------
        # /ManagedDatabaseVulnerabilityAssessments/put/Create a database's vulnerability assessment with minimal parameters[put]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_database_vulnerability_assessments.create_or_update(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            database_name=DATABASE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
            parameters=self._vulnerability_assessment_body(STORAGE_ACCOUNT_NAME, BLOB_CONTAINER_NAME, ACCESS_KEY),
        )

        # --------------------------------------------------------------------------
        # /ManagedInstanceVulnerabilityAssessments/get/Get a managed instance's vulnerability assessment[get]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_instance_vulnerability_assessments.get(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /ManagedDatabaseVulnerabilityAssessments/get/Get a database's vulnerability assessment[get]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_database_vulnerability_assessments.get(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            database_name=DATABASE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /ManagedInstanceVulnerabilityAssessments/get/Get a managed instance's vulnerability assessment policies[get]
        # --------------------------------------------------------------------------
        # list_by_* returns a lazy pager; list() forces the request to be sent.
        list(
            self.mgmt_client.managed_instance_vulnerability_assessments.list_by_instance(
                resource_group_name=RESOURCE_GROUP, managed_instance_name=MANAGED_INSTANCE_NAME
            )
        )

        # --------------------------------------------------------------------------
        # /ManagedDatabaseVulnerabilityAssessments/get/Get a database's vulnerability assessments list[get]
        # --------------------------------------------------------------------------
        list(
            self.mgmt_client.managed_database_vulnerability_assessments.list_by_database(
                resource_group_name=RESOURCE_GROUP,
                managed_instance_name=MANAGED_INSTANCE_NAME,
                database_name=DATABASE_NAME,
            )
        )

        # --------------------------------------------------------------------------
        # /ManagedInstanceVulnerabilityAssessments/delete/Remove a managed instance's vulnerability assessment[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_instance_vulnerability_assessments.delete(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /ManagedDatabaseVulnerabilityAssessments/delete/Remove a database's vulnerability assessment[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_database_vulnerability_assessments.delete(
            resource_group_name=RESOURCE_GROUP,
            managed_instance_name=MANAGED_INSTANCE_NAME,
            database_name=DATABASE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /ManagedDatabases/delete/Delete managed database[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.managed_databases.begin_delete(
            resource_group_name=RESOURCE_GROUP, managed_instance_name=MANAGED_INSTANCE_NAME, database_name=DATABASE_NAME
        ).result()

    @unittest.skip("hard to test")
    @RandomNameResourceGroupPreparer(location=AZURE_LOCATION)
    def test_vulnerability_assessment(self, resource_group):
        """Create, read, list and delete server and database vulnerability assessments.

        A server, its security alert policy and a database are created in the
        prepared resource group, and everything is deleted again at the end.

        :param resource_group: Resource group supplied by ``RandomNameResourceGroupPreparer``.
        """

        RESOURCE_GROUP = resource_group.name
        SERVER_NAME = "myserverxpxyz"
        DATABASE_NAME = "mydatabase"
        VULNERABILITY_ASSESSMENT_NAME = "default"
        STORAGE_ACCOUNT_NAME = "mystorageaccountxye"
        BLOB_CONTAINER_NAME = "myblobcontainer"
        SECURITY_ALERT_POLICY_NAME = "default"
        # NOTE: RULE_ID and BASELINE_NAME are only referenced by the commented-out
        # rule-baseline calls below (0/3 in the coverage summary).
        RULE_ID = "VA1001"
        BASELINE_NAME = "default"

        if self.is_live:
            ACCESS_KEY = self.create_blob_container(
                AZURE_LOCATION, RESOURCE_GROUP, STORAGE_ACCOUNT_NAME, BLOB_CONTAINER_NAME
            )
        else:
            # Non-live runs provision no storage, so a placeholder key is used.
            ACCESS_KEY = "accesskey"

        # --------------------------------------------------------------------------
        # /Servers/put/Create server[put]
        # --------------------------------------------------------------------------
        server_body = {
            "location": AZURE_LOCATION,
            "administrator_login": "dummylogin",
            "administrator_login_password": "Un53cuRE!",
            "version": "12.0",
        }
        self.mgmt_client.servers.begin_create_or_update(
            resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME, parameters=server_body
        ).result()

        # --------------------------------------------------------------------------
        # /ServerSecurityAlertPolicies/put/Update a server's threat detection policy with minimal parameters[put]
        # --------------------------------------------------------------------------
        policy_body = {"state": "Enabled", "email_account_admins": True, "disabled_alerts": [], "email_addresses": []}
        self.mgmt_client.server_security_alert_policies.begin_create_or_update(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            security_alert_policy_name=SECURITY_ALERT_POLICY_NAME,
            parameters=policy_body,
        ).result()

        # --------------------------------------------------------------------------
        # /ServerVulnerabilityAssessments/put/Create a server's vulnerability assessment with minimal parameters, when storageAccountAccessKey is specified[put]
        # --------------------------------------------------------------------------
        self.mgmt_client.server_vulnerability_assessments.create_or_update(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
            parameters=self._vulnerability_assessment_body(STORAGE_ACCOUNT_NAME, BLOB_CONTAINER_NAME, ACCESS_KEY),
        )

        # --------------------------------------------------------------------------
        # /Databases/put/Creates a database [put]
        # --------------------------------------------------------------------------
        self.mgmt_client.databases.begin_create_or_update(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            database_name=DATABASE_NAME,
            parameters={"location": AZURE_LOCATION},
        ).result()

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessments/put/Create a database's vulnerability assessment with minimal parameters, when storageAccountAccessKey is specified[put]
        # --------------------------------------------------------------------------
        self.mgmt_client.database_vulnerability_assessments.create_or_update(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            database_name=DATABASE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
            parameters=self._vulnerability_assessment_body(STORAGE_ACCOUNT_NAME, BLOB_CONTAINER_NAME, ACCESS_KEY),
        )

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessmentRuleBaselines/put/Creates or updates a database's vulnerability assessment rule baseline.[put]
        # --------------------------------------------------------------------------
        # NOTE: intentionally unused until the rule-baseline calls are enabled.
        BASELINE_BODY = {
            "baseline_results": [
                {"result": ["userA", "SELECT"]},
                {"result": ["userB", "SELECT"]},
                {"result": ["userC", "SELECT", "tableId_4"]},
            ]
        }
        # self.mgmt_client.database_vulnerability_assessment_rule_baselines.create_or_update(resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME, database_name=DATABASE_NAME, vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME, rule_id=RULE_ID, baseline_name=BASELINE_NAME, parameters=BASELINE_BODY)

        # --------------------------------------------------------------------------
        # /ServerVulnerabilityAssessments/get/Get a server's vulnerability assessment[get]
        # --------------------------------------------------------------------------
        self.mgmt_client.server_vulnerability_assessments.get(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessments/get/Get a database's vulnerability assessment[get]
        # --------------------------------------------------------------------------
        self.mgmt_client.database_vulnerability_assessments.get(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            database_name=DATABASE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessmentRuleBaselines/get/Gets a database's vulnerability assessment rule baseline.[get]
        # --------------------------------------------------------------------------
        # self.mgmt_client.database_vulnerability_assessment_rule_baselines.get(resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME, database_name=DATABASE_NAME, vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME, rule_id=RULE_ID, baseline_name=BASELINE_NAME)

        # --------------------------------------------------------------------------
        # /ServerVulnerabilityAssessments/get/Get a server's vulnerability assessment policies[get]
        # --------------------------------------------------------------------------
        # list_by_* returns a lazy pager; list() forces the request to be sent.
        list(
            self.mgmt_client.server_vulnerability_assessments.list_by_server(
                resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME
            )
        )

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessments/get/Get the database's vulnerability assessment policies[get]
        # --------------------------------------------------------------------------
        list(
            self.mgmt_client.database_vulnerability_assessments.list_by_database(
                resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME, database_name=DATABASE_NAME
            )
        )

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessmentRuleBaselines/delete/Removes a database's vulnerability assessment rule baseline.[delete]
        # --------------------------------------------------------------------------
        # self.mgmt_client.database_vulnerability_assessment_rule_baselines.delete(resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME, database_name=DATABASE_NAME, vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME, rule_id=RULE_ID, baseline_name=BASELINE_NAME)

        # --------------------------------------------------------------------------
        # /DatabaseVulnerabilityAssessments/delete/Remove a database's vulnerability assessment[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.database_vulnerability_assessments.delete(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            database_name=DATABASE_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /Databases/delete/Deletes a database.[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.databases.begin_delete(
            resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME, database_name=DATABASE_NAME
        ).result()

        # --------------------------------------------------------------------------
        # /ServerVulnerabilityAssessments/delete/Remove a server's vulnerability assessment[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.server_vulnerability_assessments.delete(
            resource_group_name=RESOURCE_GROUP,
            server_name=SERVER_NAME,
            vulnerability_assessment_name=VULNERABILITY_ASSESSMENT_NAME,
        )

        # --------------------------------------------------------------------------
        # /Servers/delete/Delete server[delete]
        # --------------------------------------------------------------------------
        self.mgmt_client.servers.begin_delete(resource_group_name=RESOURCE_GROUP, server_name=SERVER_NAME).result()
